#include <stdlib.h>
#include <stdio.h>
#include <float.h>
#include <math.h>
#include <gsl/gsl_linalg.h>
#include <gsl/gsl_matrix_double.h>
#include <gsl/gsl_vector_double.h>
#include <assert.h>

#include "svd.h"

// Diagnostic flag: main() clears it before running the SVD, and jacobi()
// sets it to 1 whenever its intermediate ratio w evaluates to exactly 0.
int wzero;

// Demo driver: fills a fixed 10x3 matrix, runs dosvd() on it, prints the
// resulting factors, and finally prints the product U * Sigma * Vt.
// Command-line arguments are ignored. Always returns 0.
int main(int argc, char * argv[])
{
	// Problem dimensions: A is M x N, and rows are rotated in blocks of R.
	enum { M = 10, N = 3, R = 2 };
	static const double input[M][N] = {
		{ 3.0, 4.0, 1.0 },
		{ 8.0, 1.0, 9.0 },
		{ 8.0, 7.0, 7.0 },
		{ 2.0, 2.0, 2.0 },
		{ 4.0, 6.0, 2.0 },
		{ 9.0, 1.0, 6.0 },
		{ 9.0, 6.0, 9.0 },
		{ 2.0, 4.0, 8.0 },
		{ 8.0, 6.0, 3.0 },
		{ 8.0, 2.0, 5.0 }
	};
	// Sigma is M x N, so it has only min(M, N) diagonal cells.
	const unsigned int diagLen = (M < N) ? M : N;
	gsl_matrix * A;
	gsl_vector * S;
	gsl_vector * sigma;
	gsl_matrix_complex * P;
	gsl_matrix * Ut;
	gsl_matrix * Vt;
	gsl_matrix * U;
	gsl_matrix * Sigma;
	gsl_matrix * tempProduct;
	gsl_matrix * productMatrix;
	unsigned int i, j;

	(void)argc;
	(void)argv;

	wzero = 0;
	A = gsl_matrix_alloc(M, N);
	S = gsl_vector_alloc(M);
	sigma = gsl_vector_alloc(M);
	P = gsl_matrix_complex_alloc(R, R);
	Ut = gsl_matrix_alloc(M, M);
	Vt = gsl_matrix_alloc(N, N);

	for (i = 0; i < M; i++)
	{
		for (j = 0; j < N; j++)
		{
			gsl_matrix_set(A, i, j, input[i][j]);
		}
	}

	print_matrix("A_original", A);

	dosvd(A, S, P, sigma, Ut, Vt);

	printf("SVD computation complete\n");

	print_vector("sigma", sigma);

	print_matrix("A", A);
	print_matrix("Ut", Ut);
	print_matrix("Vt", Vt);

	printf("wzero = %d\n", wzero);

	U = gsl_matrix_alloc(M, M);
	gsl_matrix_transpose_memcpy(U, Ut);

	print_matrix("U", U);

	// Convert sigma vector into a Sigma matrix, by placing the non-zero
	// values down the diagonal. The j < diagLen bound keeps the write inside
	// Sigma even if more than diagLen entries of sigma are non-zero.
	Sigma = gsl_matrix_alloc(M, N);
	gsl_matrix_set_zero(Sigma);
	j = 0;
	for (i = 0; i < sigma->size && j < diagLen; i++)
	{
		double sigma_current = gsl_vector_get(sigma, i);
		if (fabs(sigma_current) > DBL_EPSILON)
		{
			gsl_matrix_set(Sigma, j, j, sigma_current);
			j++;
		}
	}

	print_matrix("Sigma", Sigma);

	// Want: U * Sigma * Vt
	tempProduct = gsl_matrix_alloc(M, N);
	productMatrix = gsl_matrix_alloc(M, N);
	gsl_linalg_matmult(U, Sigma, tempProduct);
	gsl_linalg_matmult(tempProduct, Vt, productMatrix);

	print_matrix("U * Sigma * Vt", productMatrix);

	// Release everything allocated above, including the matrices that were
	// only needed for the reconstruction check.
	gsl_matrix_free(productMatrix);
	gsl_matrix_free(tempProduct);
	gsl_matrix_free(Sigma);
	gsl_matrix_free(U);
	gsl_matrix_free(A);
	gsl_vector_free(S);
	gsl_vector_free(sigma);
	gsl_matrix_complex_free(P);
	gsl_matrix_free(Ut);
	gsl_matrix_free(Vt);

	return 0;
}

// Write `matrix` to stdout under the label `name`: a header line with the
// dimensions, then one text row per matrix row, rows separated by ";\n" and
// the whole thing wrapped in square brackets. Entries whose magnitude does
// not exceed DBL_EPSILON are printed as 0.
void print_matrix(const char* name, const gsl_matrix* matrix)
{
	const int rowCount = matrix->size1;
	const int colCount = matrix->size2;
	int r, c;

	printf("Matrix %s(%d,%d) = \n[", name, rowCount, colCount);

	for (r = 0; r < rowCount; r++)
	{
		// Separator goes between rows, never after the last one.
		if (r > 0)
		{
			printf(";\n");
		}

		for (c = 0; c < colCount; c++)
		{
			double value = gsl_matrix_get(matrix, r, c);

			if (!(fabs(value) > DBL_EPSILON))
			{
				value = 0.0;
			}
			printf("%10.4f", value);
		}
	}

	printf("]\n");
}

// Write `vector` to stdout under the label `name`: a header line with the
// length, then one element per line inside square brackets. Entries whose
// magnitude does not exceed DBL_EPSILON are printed as 0.
void print_vector(const char* name, const gsl_vector* vector)
{
	const int length = vector->size;
	int pos;

	printf("Vector %s(%d) = \n[", name, length);

	for (pos = 0; pos < length; pos++)
	{
		double value;

		// Newline goes between elements, never after the last one.
		if (pos > 0)
		{
			printf("\n");
		}

		value = gsl_vector_get(vector, pos);
		if (!(fabs(value) > DBL_EPSILON))
		{
			value = 0.0;
		}
		printf("%10.4f", value);
	}

	printf("]\n");
}

// Compute a singular value decomposition of A using one-sided Jacobi
// rotations applied to pairs of rows, processed in blocks of R = P->size1
// rows. Iterates until every pairwise row dot product is within
// DBL_EPSILON * (sum of the initial squared row norms).
//
// Parameters:
//   A     - M x N input matrix; overwritten with the rotated rows.
//   S     - scratch vector with at least M entries; holds squared row norms.
//   P     - R x R complex scratch matrix; entry (i % R, j % R) stores the
//           rotation for row pair (i, j) as real = c, imag = s.
//   sigma - output vector whose size must equal M; receives the row norms of
//           the rotated A in descending order, with norms that do not exceed
//           DBL_EPSILON stored as 0.
//   Ut    - M x M output; reset to the identity, accumulates every rotation,
//           then has its rows permuted the same way as sigma.
//   Vt    - output with N columns; zeroed, then row k receives the normalised
//           row of Ut * A_original belonging to the k-th sorted non-zero
//           sigma. Non-zero sigmas beyond Vt's row count get no Vt row.
void dosvd(gsl_matrix * A, gsl_vector * S, gsl_matrix_complex * P, gsl_vector * sigma, gsl_matrix * Ut, gsl_matrix * Vt)
{
	unsigned int M, N, R, i, j, k;
	unsigned int lbi, ubi, lbj, ubj;
	double g, dblTemp, dblTemp2;
	double SSum;
	double delta;
	int converged;
	size_t targetIndex;
	gsl_complex tempComplex;
	gsl_matrix * A_original;
	gsl_matrix * B;
	gsl_matrix * Ut_sorted;
	gsl_vector * sigma_sorted;
	gsl_vector * sigma_sorted_indices;
	gsl_vector_view Vtrow;
	
	M = A->size1;
	N = A->size2;
	R = P->size1;
	gsl_vector_set_all(sigma, 0.0);

	// Rows of Vt that never get assigned (rank-deficient input) must not keep
	// whatever the caller's buffer happened to contain.
	gsl_matrix_set_zero(Vt);
	
	// Keep an untouched copy of the input: A itself is rotated in place below.
	A_original = gsl_matrix_alloc(A->size1, A->size2);
	gsl_matrix_memcpy(A_original, A);

	// Squared row norms, and their sum as the scale for the convergence test.
	SSum = 0.0;
	for (i = 0; i < M; i++)
	{
		dblTemp = dotproduct(gsl_matrix_row(A, i), gsl_matrix_row(A, i));
		gsl_vector_set(S, i, dblTemp);
		SSum += dblTemp;
	}
	delta = DBL_EPSILON * SSum;

	gsl_matrix_set_identity(Ut);

	do
	{
		converged = 1;
		for (lbi = 0; lbi < M; lbi += R)
		{
			ubi = (lbi + R < M) ? (lbi + R) : M;
			for (lbj = lbi; lbj < M; lbj += R)
			{
				ubj = (lbj + R < M) ? (lbj + R) : M;

				// Refresh the squared norms of the rows in both blocks.
				for (i = lbi; i < ubi; i++)
				{
					gsl_vector_set(S, i, dotproduct(gsl_matrix_row(A, i), gsl_matrix_row(A, i)));
				}
				for (j = lbj; j < ubj; j++)
				{
					gsl_vector_set(S, j, dotproduct(gsl_matrix_row(A, j), gsl_matrix_row(A, j)));
				}

				// Phase 1: compute a rotation for every row pair (i < j) of
				// this block pair and park it in P.
				for (i = lbi; i < ubi; i++)
				{
					for (j = lbj; j < ubj; j++)
					{
						if (i < j)
						{
							g = dotproduct(gsl_matrix_row(A, i), gsl_matrix_row(A, j));
							if (fabs(g) > delta)
							{
								converged = 0;
							}

							if (fabs(g) > DBL_EPSILON)
							{
								tempComplex = jacobi(gsl_vector_get(S, i), gsl_vector_get(S, j), g);
							}
							else
							{
								// Rows already orthogonal: identity rotation (c = 1, s = 0).
								GSL_SET_COMPLEX(&tempComplex, 1.0, 0.0);
							}
							gsl_matrix_complex_set(P, i % R, j % R, tempComplex);
						}
					}
				}

				// Phase 2: apply the stored rotations to the rows of A and,
				// identically, to the rows of Ut.
				for (i = lbi; i < ubi; i++)
				{
					for (j = lbj; j < ubj; j++)
					{
						if (i < j)
						{
							tempComplex = gsl_matrix_complex_get(P, i % R, j % R);
							for (k = 0; k < N; k++)
							{
								dblTemp = GSL_REAL(tempComplex) * gsl_matrix_get(A, i, k) - GSL_IMAG(tempComplex) * gsl_matrix_get(A, j, k);
								dblTemp2 = GSL_IMAG(tempComplex) * gsl_matrix_get(A, i, k) + GSL_REAL(tempComplex) * gsl_matrix_get(A, j, k);
								gsl_matrix_set(A, i, k, dblTemp);
								gsl_matrix_set(A, j, k, dblTemp2);
							}
							for (k = 0; k < M; k++)
							{
								dblTemp = GSL_REAL(tempComplex) * gsl_matrix_get(Ut, i, k) - GSL_IMAG(tempComplex) * gsl_matrix_get(Ut, j, k);
								dblTemp2 = GSL_IMAG(tempComplex) * gsl_matrix_get(Ut, i, k) + GSL_REAL(tempComplex) * gsl_matrix_get(Ut, j, k);
								gsl_matrix_set(Ut, i, k, dblTemp);
								gsl_matrix_set(Ut, j, k, dblTemp2);
							}
						}
					}
				}
			} 
		}
	}
	while (!converged);
	
	// B = Ut * A_original; its rows are normalised below to form Vt.
	B = gsl_matrix_alloc(M, N);
	gsl_linalg_matmult(Ut, A_original, B);

	// Singular values are the norms of the rotated rows of A. At this point
	// they are not ordered by magnitude.
	for (i = 0; i < M; i++)
	{
		dblTemp2 = norm(gsl_matrix_row(A, i));
		gsl_vector_set(sigma, i, fabs(dblTemp2) > DBL_EPSILON ? dblTemp2 : 0.0);
	}

	// sigma should be ordered from greatest to least magnitude. Get the
	// permutation that sorts it; the same permutation is applied to the rows
	// of Ut and to the choice of the row of B that becomes each row of Vt, so
	// that sigma[k], row k of Ut and row k of Vt all describe the same triple.
	sigma_sorted_indices = gsl_vector_alloc(M);
	get_sorted_vector_order(sigma_sorted_indices, sigma);

	sigma_sorted = gsl_vector_alloc(sigma->size);
	Ut_sorted = gsl_matrix_alloc(Ut->size1, Ut->size2);

	for (i = 0; i < sigma_sorted_indices->size; i++)
	{
		targetIndex = (size_t)gsl_vector_get(sigma_sorted_indices, i);
		dblTemp2 = gsl_vector_get(sigma, targetIndex);

		gsl_vector_set(sigma_sorted, i, dblTemp2);

		gsl_vector_const_view row = gsl_matrix_const_row(Ut, targetIndex);
		gsl_matrix_set_row(Ut_sorted, i, &(row.vector) );

		// Non-zero sigmas sort to the front, so i is also the Vt row index.
		// The bound keeps the write inside Vt if more rows than Vt can hold
		// ended up with a norm above DBL_EPSILON.
		if (dblTemp2 > 0.0 && i < Vt->size1)
		{
			Vtrow = gsl_matrix_row(B, targetIndex);
			dblTemp = norm(Vtrow);
			gsl_vector_scale(&Vtrow.vector, 1.0 / dblTemp);
			gsl_matrix_set_row(Vt, i, &Vtrow.vector);
		}
	}

	gsl_vector_swap(sigma, sigma_sorted);
	gsl_matrix_swap(Ut, Ut_sorted);

	// Release every temporary allocated in this function.
	gsl_matrix_free(B);
	gsl_matrix_free(A_original);
	gsl_matrix_free(Ut_sorted);
	gsl_vector_free(sigma_sorted);
	gsl_vector_free(sigma_sorted_indices);
}

// Fill index_vect with the indices of data_vect arranged so that the
// referenced values run from largest to smallest. data_vect is left
// untouched; both vectors must have the same length (checked with assert).
void get_sorted_vector_order(gsl_vector* index_vect, const gsl_vector* data_vect)
{
	// Selection by repeated maximum on a scratch copy: pick the current
	// maximum, record where it was, then overwrite it with -infinity so the
	// next pass finds the runner-up. Quadratic, which is fine for small inputs.
	const size_t count = data_vect->size;
	gsl_vector* scratch;
	size_t rank = 0;

	assert(index_vect->size == count);

	scratch = gsl_vector_alloc(count);
	gsl_vector_memcpy(scratch, data_vect);

	while (rank < count)
	{
		const size_t winner = gsl_vector_max_index(scratch);

		gsl_vector_set(index_vect, rank, winner);
		gsl_vector_set(scratch, winner, -HUGE_VAL);
		rank++;
	}

	gsl_vector_free(scratch);
}

// Return the inner product of the vectors viewed by A and B.
// The loop runs over A's length; B is read at the same indices, so it is
// expected to have at least as many elements as A.
double dotproduct(gsl_vector_view A, gsl_vector_view B)
{
	// size_t matches the type of the vector length, avoiding a signed/unsigned
	// comparison and index overflow on very long vectors.
	const size_t length = A.vector.size;
	double dp = 0.0;
	size_t i;

	for (i = 0; i < length; i++)
	{
		dp += gsl_vector_get(&A.vector, i) * gsl_vector_get(&B.vector, i);
	}
	return dp;
}

// Compute the Jacobi rotation for a pair of rows with squared norms a and b
// and mutual dot product g (this is how dosvd() calls it).
//
// Returns a complex number used as a (c, s) pair: real part c = 1/sqrt(1+t^2)
// and imaginary part s = t*c, where t is the smaller-magnitude root of
// t^2 + 2*w*t - 1 = 0 with w = (b - a) / (2*g).
//
// g must be non-zero; the caller is expected to skip the call otherwise.
// Side effects: when w is exactly 0 a message is printed and the global
// wzero flag is set to 1.
gsl_complex jacobi(double a, double b, double g)
{
	double w, t, sgnw, c;
	gsl_complex retval;
	
	w = (b - a)/(2 * g);
	if (w == 0.0)
	{
		printf("W is zero!!\n");
		wzero = 1;
	}
	// sign(0) has to count as +1. With w == 0 (equal norms) the roots are
	// t = +/-1, i.e. a 45 degree rotation; treating the sign as 0 would give
	// t = 0, an identity rotation that never removes g.
	sgnw = (w < 0.0) ? -1.0 : 1.0;
	t = sgnw / (fabs(w) + sqrt(1.0 + w*w));
	c = 1.0 / sqrt(1.0 + t * t);
	GSL_SET_REAL(&retval, c);
	GSL_SET_IMAG(&retval, t * c);
	return retval;
}

// Euclidean length of the vector viewed by A: the square root of its dot
// product with itself.
double norm(gsl_vector_view A)
{
	const double squaredLength = dotproduct(A, A);

	return sqrt(squaredLength);
}
